import Kernel, except: [round: 1]

defmodule Float do
  @moduledoc """
  Functions for working with floating-point numbers.
  """

  import Bitwise

  # 2**52: the weight of the implicit leading bit of a normal IEEE 754 double.
  @power_of_2_to_52 4_503_599_627_370_496
  @precision_range 0..15
  @type precision_range :: 0..15

  @doc """
  Computes `base` raised to power of `exponent`.

  `base` must be a float; `exponent` may be any number. The result is a float.

  ## Examples

      iex> Float.pow(2.0, 3)
      8.0

  """
  @doc since: "1.12.0"
  @spec pow(float, number) :: float
  def pow(base, exponent) when is_float(base) and is_number(exponent),
    do: :math.pow(base, exponent)

  @doc """
  Parses a binary into a float.

  Returns `{float, rest}`, where `rest` is the part of the binary that was not
  consumed. Returns `:error` when, after an optional leading `-` or `+`, the
  binary does not start with a decimal digit.

  ## Examples

      iex> Float.parse("34.25")
      {34.25, ""}

      iex> Float.parse("56.5xyz")
      {56.5, "xyz"}

      iex> Float.parse("pi")
      :error

  """
  @spec parse(binary) :: {float, binary} | :error
  def parse("-" <> binary) do
    case parse_unsigned(binary) do
      :error -> :error
      {number, remainder} -> {-number, remainder}
    end
  end

  def parse("+" <> binary) do
    parse_unsigned(binary)
  end

  def parse(binary) do
    parse_unsigned(binary)
  end

  # Entry point: the unsigned part must start with a digit.
  defp parse_unsigned(<<digit, rest::binary>>) when digit in ?0..?9,
    do: parse_unsigned(rest, false, false, <<digit>>)

  defp parse_unsigned(binary) when is_binary(binary), do: :error

  # Scanner state: `dot?` is true once the accumulator already contains a
  # fractional part (or an exponent, which also makes it a valid float
  # literal); `e?` is true once an exponent has been consumed.
  defp parse_unsigned(<<digit, rest::binary>>, dot?, e?, acc) when digit in ?0..?9,
    do: parse_unsigned(rest, dot?, e?, <<acc::binary, digit>>)

  # A dot is only consumed when followed by a digit, so "1." parses as {1.0, "."}.
  defp parse_unsigned(<<?., digit, rest::binary>>, false, false, acc) when digit in ?0..?9,
    do: parse_unsigned(rest, true, false, <<acc::binary, ?., digit>>)

  # Exponent markers are only consumed when followed by a digit (optionally
  # signed). `add_dot/2` inserts ".0" because Erlang rejects "1e5" but
  # accepts "1.0e5".
  defp parse_unsigned(<<exp_marker, digit, rest::binary>>, dot?, false, acc)
       when exp_marker in 'eE' and digit in ?0..?9,
       do: parse_unsigned(rest, true, true, <<add_dot(acc, dot?)::binary, ?e, digit>>)

  defp parse_unsigned(<<exp_marker, sign, digit, rest::binary>>, dot?, false, acc)
       when exp_marker in 'eE' and sign in '-+' and digit in ?0..?9,
       do: parse_unsigned(rest, true, true, <<add_dot(acc, dot?)::binary, ?e, sign, digit>>)

  # NOTE(review): an exponent outside the double range (e.g. "1e400") likely
  # makes :erlang.binary_to_float/1 raise instead of returning :error — confirm.
  defp parse_unsigned(rest, dot?, _e?, acc),
    do: {:erlang.binary_to_float(add_dot(acc, dot?)), rest}

  defp add_dot(acc, true), do: acc
  defp add_dot(acc, false), do: acc <> ".0"

  @doc """
  Rounds `number` down (towards negative infinity) to `precision` fractional
  digits.

  `precision` must be an integer in `0..15`; otherwise an `ArgumentError` is
  raised.

  ## Examples

      iex> Float.floor(-56.5)
      -57.0

      iex> Float.floor(34.259, 2)
      34.25

  """
  @spec floor(float, precision_range) :: float
  def floor(number, precision \\ 0)

  def floor(number, 0) when is_float(number) do
    :math.floor(number)
  end

  def floor(number, precision) when is_float(number) and precision in @precision_range do
    round(number, precision, :floor)
  end

  def floor(number, precision) when is_float(number) do
    raise ArgumentError, invalid_precision_message(precision)
  end

  @doc """
  Rounds `number` up (towards positive infinity) to `precision` fractional
  digits.

  `precision` must be an integer in `0..15`; otherwise an `ArgumentError` is
  raised.

  ## Examples

      iex> Float.ceil(-56.5)
      -56.0

      iex> Float.ceil(34.251, 2)
      34.26

  """
  @spec ceil(float, precision_range) :: float
  def ceil(number, precision \\ 0)

  def ceil(number, 0) when is_float(number) do
    :math.ceil(number)
  end

  def ceil(number, precision) when is_float(number) and precision in @precision_range do
    round(number, precision, :ceil)
  end

  def ceil(number, precision) when is_float(number) do
    raise ArgumentError, invalid_precision_message(precision)
  end

  @doc """
  Rounds a floating-point value to an arbitrary number of fractional
  digits (between 0 and 15).

  Ties are rounded away from zero. Rounding is applied to the exact binary
  value of the float, so a literal such as `5.5675`, whose binary value lies
  slightly below the decimal tie, rounds down.

  `precision` outside `0..15` raises an `ArgumentError`.

  ## Examples

      iex> Float.round(-5.5)
      -6.0

  """
  @spec round(float, precision_range) :: float
  # This implementation is slow since it relies on big integers.
  # Faster implementations are available on more recent papers
  # and could be implemented in the future.
  def round(float, precision \\ 0)

  def round(float, 0) when is_float(float) do
    float |> :erlang.round() |> :erlang.float()
  end

  def round(float, precision) when is_float(float) and precision in @precision_range do
    round(float, precision, :half_up)
  end

  def round(float, precision) when is_float(float) do
    raise ArgumentError, invalid_precision_message(precision)
  end

  # Rounds `float` to `precision` fractional digits in the given mode
  # (:floor, :ceil or :half_up). It converts the float to an exact big-integer
  # decimal, rounds there, and converts back with correct binary rounding.
  #
  # Zero is detected with a guard, not the `0.0` pattern: since OTP 27 the
  # pattern `0.0` no longer matches `-0.0`. Without the guard, `-0.0` would
  # reach the bit-level code below and, for example, floor(-0.0, 2) would
  # return -0.01.
  defp round(num, _precision, _rounding) when num == 0.0, do: num

  defp round(float, precision, rounding) do
    <<sign::1, exp::11, significant::52-bitstring>> = <<float::float>>

    # `num` is the significand with its implicit leading 1 and without
    # trailing zero bits. It is therefore always odd.
    {num, count} = decompose(significant, 1)

    # After this line abs(float) == num / 2**count. `count` is the number of
    # binary fractional digits, which equals the number of decimal fractional
    # digits needed to write the value exactly (2**-k has exactly k of them).
    # Subnormals (exp == 0) do not really have the implicit 1. They always get
    # count >= 1023, so they take the first branch below, where `num` is not used.
    count = count - exp + 1023

    cond do
      # abs(float) < 2**53 / 2**104 = 2**-51 < 10**-15. At any precision in
      # 0..15 the result is therefore either zero or one unit of the last
      # requested digit.
      count >= 104 ->
        case rounding do
          :ceil when sign === 0 -> 1 / power_of_10(precision)
          :floor when sign === 1 -> -1 / power_of_10(precision)
          _ -> 0.0
        end

      # The float already has at most `precision` fractional digits.
      count <= precision ->
        float

      true ->
        # Difference in precision between float and asked precision
        # We subtract 1 because we need to calculate the remainder too
        diff = count - precision - 1

        # Get up to latest so we calculate the remainder
        power_of_10 = power_of_10(diff)

        # Convert to an exact decimal integer: abs(float) * 10**count
        num = num * power_of_5(count)

        # Keep precision + 1 fractional digits; the last one drives rounding
        num = div(num, power_of_10)
        div = div(num, 10)
        num = rounding(rounding, sign, num, div)

        # Convert back to float without loss
        # https://www.exploringbinary.com/correct-decimal-to-floating-point-using-big-integers/
        den = power_of_10(precision)
        boundary = den <<< 52

        cond do
          num == 0 ->
            0.0

          num >= boundary ->
            {den, exp} = scale_down(num, boundary, 52)
            decimal_to_float(sign, num, den, exp)

          true ->
            {num, exp} = scale_up(num, boundary, 52)
            decimal_to_float(sign, num, den, exp)
        end
    end
  end

  # Builds an integer from the 52 significand bits, starting from `initial`
  # (the implicit leading bit) and stopping at the last set bit, so trailing
  # zero bits are dropped. Returns {integer, position_of_last_set_bit}; the
  # position is 1-based, or 0 when no bit is set.
  defp decompose(significant, initial) do
    decompose(significant, 1, 0, initial)
  end

  defp decompose(<<1::1, bits::bitstring>>, count, last_count, acc) do
    decompose(bits, count + 1, count, (acc <<< (count - last_count)) + 1)
  end

  defp decompose(<<0::1, bits::bitstring>>, count, last_count, acc) do
    decompose(bits, count + 1, last_count, acc)
  end

  defp decompose(<<>>, _count, last_count, acc) do
    {acc, last_count}
  end

  # Doubles `num` (lowering the binary exponent) until num >= boundary.
  defp scale_up(num, boundary, exp) when num >= boundary, do: {num, exp}
  defp scale_up(num, boundary, exp), do: scale_up(num <<< 1, boundary, exp - 1)

  # Doubles the denominator (raising the binary exponent) until num < 2 * den.
  # Returns the denominator divided by 2**52, so that num / den lies in
  # [2**52, 2**53).
  defp scale_down(num, den, exp) do
    new_den = den <<< 1

    if num < new_den do
      {den >>> 52, exp}
    else
      scale_down(num, new_den, exp + 1)
    end
  end

  # Computes num / den (expected in [2**52, 2**53)) rounded to nearest, ties
  # to even. It then removes the implicit leading bit and assembles the IEEE
  # 754 bits for sign, `exp` and the 52-bit fraction.
  defp decimal_to_float(sign, num, den, exp) do
    quo = div(num, den)
    rem = num - quo * den

    tmp =
      case den >>> 1 do
        den when rem > den -> quo + 1
        den when rem < den -> quo
        _ when (quo &&& 1) === 1 -> quo + 1
        _ -> quo
      end

    tmp = tmp - @power_of_2_to_52
    <<tmp::float>> = <<sign::1, exp + 1023::11, tmp::52>>
    tmp
  end

  # `num` holds the magnitude with precision + 1 digits and `div` holds it with
  # precision digits. For :floor on negatives and :ceil on positives, the
  # magnitude is always bumped. That is safe because the exact value has
  # `count` > precision fractional digits and ends in a non-zero digit (its
  # binary significand is odd), so the discarded part is never zero.
  defp rounding(:floor, 1, _num, div), do: div + 1
  defp rounding(:ceil, 0, _num, div), do: div + 1

  defp rounding(:half_up, _sign, num, div) do
    case rem(num, 10) do
      rem when rem < 5 -> div
      rem when rem >= 5 -> div + 1
    end
  end

  defp rounding(_, _, _, div), do: div

  # Compile-time lookup clauses power_of_10(0..104) and power_of_5(0..104).
  Enum.reduce(0..104, 1, fn x, acc ->
    defp power_of_10(unquote(x)), do: unquote(acc)
    acc * 10
  end)

  Enum.reduce(0..104, 1, fn x, acc ->
    defp power_of_5(unquote(x)), do: unquote(acc)
    acc * 5
  end)

  @doc """
  Returns a pair of integers whose ratio is exactly equal
  to the original float and with a positive denominator.

  The pair is in lowest terms, and both `0.0` and `-0.0` return `{0, 1}`.

  ## Examples

      iex> Float.ratio(1.5)
      {3, 2}

      iex> Float.ratio(-0.0)
      {0, 1}

  """
  @doc since: "1.4.0"
  @spec ratio(float) :: {integer, pos_integer}
  # A guard instead of the `0.0` pattern, so `-0.0` is covered on OTP 27+ too.
  # Without it, -0.0 would fall through to the subnormal branch and return
  # the unreduced {0, 2**1074}.
  def ratio(float) when is_float(float) and float == 0.0, do: {0, 1}

  def ratio(float) when is_float(float) do
    <<sign::1, exp::11, mantissa::52>> = <<float::float>>

    {num, den_exp} =
      if exp != 0 do
        # Floats are expressed like this:
        # (2**52 + mantissa) * 2**(-52 + exp - 1023)
        #
        # Strip the trailing zero bits (factors of two) of the mantissa,
        # `count` of them, so we have this:
        # (2**52 + mantissa * 2**count) * 2**(-52 + exp - 1023)
        {mantissa, count} = root_factors(mantissa, 0)

        # Now we can move the count around so we have this:
        # (2**(52-count) + mantissa) * 2**(count + -52 + exp - 1023)
        if mantissa == 0 do
          {1, exp - 1023}
        else
          num = (1 <<< (52 - count)) + mantissa
          den_exp = count - 52 + exp - 1023
          {num, den_exp}
        end
      else
        # Subnormals are expressed like this:
        # (mantissa) * 2**(-52 + 1 - 1023)
        #
        # So we compute it to this:
        # (mantissa * 2**(count)) * 2**(-52 + 1 - 1023)
        #
        # Which becomes:
        # mantissa * 2**(count-1074)
        root_factors(mantissa, -1074)
      end

    if den_exp > 0 do
      {sign(sign, num <<< den_exp), 1}
    else
      {sign(sign, num), 1 <<< -den_exp}
    end
  end

  # Shifts out trailing zero bits of a non-zero `mantissa`, adding one to
  # `count` per bit removed.
  defp root_factors(mantissa, count) when mantissa != 0 and (mantissa &&& 1) == 0,
    do: root_factors(mantissa >>> 1, count + 1)

  defp root_factors(mantissa, count),
    do: {mantissa, count}

  @compile {:inline, sign: 2}
  defp sign(0, num), do: num
  defp sign(1, num), do: -num

  @doc """
  Returns a charlist which corresponds to the text representation
  of the given float.
  """
  @spec to_charlist(float) :: charlist
  def to_charlist(float) when is_float(float) do
    :io_lib_format.fwrite_g(float)
  end

  @doc """
  Returns a binary which corresponds to the text representation
  of the given float.

  ## Examples

      iex> Float.to_string(7.0)
      "7.0"

  """
  @spec to_string(float) :: String.t()
  def to_string(float) when is_float(float) do
    IO.iodata_to_binary(:io_lib_format.fwrite_g(float))
  end

  @doc false
  @deprecated "Use Float.to_charlist/1 instead"
  def to_char_list(float), do: Float.to_charlist(float)

  @doc false
  @deprecated "Use :erlang.float_to_list/2 instead"
  def to_char_list(float, options) do
    :erlang.float_to_list(float, expand_compact(options))
  end

  @doc false
  @deprecated "Use :erlang.float_to_binary/2 instead"
  def to_string(float, options) do
    :erlang.float_to_binary(float, expand_compact(options))
  end

  defp invalid_precision_message(precision) do
    "precision #{precision} is out of valid range of #{inspect(@precision_range)}"
  end

  # Translates the legacy `compact: boolean` option into the bare `:compact`
  # atom expected by :erlang.float_to_list/2 and :erlang.float_to_binary/2.
  defp expand_compact([{:compact, false} | t]), do: expand_compact(t)
  defp expand_compact([{:compact, true} | t]), do: [:compact | expand_compact(t)]
  defp expand_compact([h | t]), do: [h | expand_compact(t)]
  defp expand_compact([]), do: []
end
